import torch
import torch.multiprocessing as mp
import argparse
import os
import numpy as np
from tqdm import tqdm
import time
import csv 
import logging 

from environments.cvrp_env import CVRPEnv
from agents.a3c_agent import A3CAgent
from utils.data_generator import VRPDataGenerator

# 设置基本日志配置
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def main():
    parser = argparse.ArgumentParser(description="A3C算法求解CVRP问题")
    parser.add_argument('--num_workers', type=int, default=os.cpu_count(), help="A3C工作进程数量")
    parser.add_argument('--max_episodes', type=int, default=10000, help="总训练回合数")
    parser.add_argument('--lr', type=float, default=1e-5, help="学习率")
    parser.add_argument('--gamma', type=float, default=0.99, help="折扣因子")
    parser.add_argument('--embedding_dim', type=int, default=128, help="节点嵌入维度")
    parser.add_argument('--hidden_dim', type=int, default=128, help="LSTM/注意力机制的隐藏层维度")
    parser.add_argument('--value_loss_coeff', type=float, default=0.5, help="价值损失系数")
    parser.add_argument('--entropy_coeff', type=float, default=0.01, help="熵正则化系数")
    
    # 数据生成参数
    parser.add_argument('--min_nodes', type=int, default=10, help="最小客户节点数（不包括仓库）")
    parser.add_argument('--max_nodes', type=int, default=20, help="最大客户节点数（不包括仓库）")
    parser.add_argument('--min_capacity', type=float, default=20.0, help="最小车辆容量")
    parser.add_argument('--max_capacity', type=float, default=50.0, help="最大车辆容量")
    parser.add_argument('--min_demand', type=int, default=1, help="最小客户需求")
    parser.add_argument('--max_demand', type=int, default=10, help="最大客户需求")

    # CUDA设备管理
    parser.add_argument('--force_cpu', action='store_true', help='强制使用CPU，即使CUDA可用')

    # 评估参数
    parser.add_argument('--evaluate_episodes', type=int, default=100, help="评估实例数量")
    parser.add_argument('--eval_log_file', type=str, default='evaluation_summary.txt', help="评估结果保存路径")
    parser.add_argument('--save_eval_raw_data', action='store_true', help='将原始评估数据保存为CSV')
    parser.add_argument('--eval_raw_data_file', type=str, default='evaluation_raw_data.csv', help='原始评估数据CSV保存路径')

    # 训练日志间隔
    parser.add_argument('--log_interval', type=int, default=100, help="训练进度日志记录间隔")

    # 模型保存参数
    parser.add_argument('--save_model', action='store_true', help='训练后保存最终模型')
    parser.add_argument('--model_save_path', type=str, default='models/cvrp_model.pth', help='模型保存路径')

    # 评估模式
    parser.add_argument('--eval_only', action='store_true', help='仅运行评估，不进行训练')

    args = parser.parse_args()

    # CUDA设备检测
    if args.force_cpu:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
    logging.info(f"使用设备: {device}")
    if device.type == 'cuda':
        try:
            print(f"CUDA设备名称: {torch.cuda.get_device_name(0)}")
        except torch.cuda.CudaError as e:
             logging.warning(f"无法获取CUDA设备名称: {e}。设备可能不可用。")
             logging.warning("由于CUDA错误，尝试切换到CPU。")
             device = torch.device("cpu")
             logging.info(f"已切换到设备: {device}")
        except Exception as e:
             logging.warning(f"检查CUDA设备时发生意外错误: {e}")
             logging.warning("由于意外错误，尝试切换到CPU。")
             device = torch.device("cpu")
             logging.info(f"已切换到设备: {device}")

    # 数据生成器（共享）
    data_generator = VRPDataGenerator(
        min_nodes=args.min_nodes, max_nodes=args.max_nodes,
        capacity_range=(args.min_capacity, args.max_capacity),
        demand_range=(args.min_demand, args.max_demand)
    )
    args.data_generator = data_generator

    # 初始环境配置
    env_config = {
        'num_nodes': args.max_nodes,
        'capacity': args.max_capacity,
        'depot_idx': 0,
        'input_dim': 4
    }

    # 设置多进程启动方法
    try:
        if mp.get_start_method(allow_none=True) is None:
             mp.set_start_method('spawn', force=True)
             logging.info("多进程启动方法设置为'spawn'。")
        else:
             logging.info(f"多进程启动方法已设置为'{mp.get_start_method()}'。")
    except RuntimeError as e:
        logging.warning(f"无法设置多进程启动方法: {e}")
        pass

    # 创建智能体
    agent = A3CAgent(env_config, args, device)
    
    # 加载预训练模型（可选）
    model_path = "models/cvrp_model.pth"
    if os.path.exists(model_path):
        logging.info(f"从{model_path}加载模型")
        try:
            agent.global_actor_critic.load_state_dict(torch.load(model_path, map_location=device))
            logging.info("模型加载成功。")
        except Exception as e:
            logging.error(f"加载模型时出错: {e}")
    else:
        logging.info("未找到预训练模型，从头开始训练。")

    # --- Training ---
    if not args.eval_only:
        logging.info("开始训练...")
        agent.train()
        logging.info("训练完成！")
        
        # 保存模型
        if args.save_model:
            os.makedirs("models", exist_ok=True)
            torch.save(agent.global_actor_critic.state_dict(), "models/cvrp_model.pth")
            logging.info("模型已保存到 models/cvrp_model.pth")
    
    # --- Evaluation ---
    if args.eval_only:
        logging.info("开始评估...")
        eval_results = agent.evaluate()
        
        # 保存评估结果
        if args.save_eval_raw_data:
            os.makedirs("results", exist_ok=True)
            with open("results/eval_results.csv", "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerow(["问题规模", "平均路径长度", "标准差"])
                for size, (mean, std) in eval_results.items():
                    writer.writerow([size, mean, std])
            logging.info("评估结果已保存到 results/eval_results.csv")
        
        # 打印评估结果
        logging.info("\n评估结果:")
        distances = [r['distance'] for r in eval_results]
        vehicles = [r['vehicles'] for r in eval_results]
        times = [r['time'] for r in eval_results]

        logging.info(f"平均总距离: {np.mean(distances):.2f} ± {np.std(distances):.2f}")
        logging.info(f"平均使用车辆数: {np.mean(vehicles):.2f} ± {np.std(vehicles):.2f}")
        logging.info(f"平均计算时间: {np.mean(times):.4f}s")
    
    logging.info("程序执行完成！")

if __name__ == '__main__':
    main()